{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch_geometric\n",
    "import deepdish as dd\n",
    "from tqdm import tqdm\n",
    "import copy\n",
    "from os.path import join\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import ChemicalFeatures, AllChem\n",
    "from rdkit import RDConfig\n",
    "from rdkit.Chem.rdmolfiles import MolFromMolFile\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.cluster import AgglomerativeClustering\n",
    "import random\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import Sequential, Linear, ReLU, GRU, BatchNorm1d, Dropout\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.nn import  GCNConv, global_add_pool, global_mean_pool,GATConv,GINConv\n",
    "from torch.nn import Sequential, Linear, ReLU\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch_geometric.utils import remove_self_loops\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.data import InMemoryDataset \n",
    "from Models.graphLambda import Net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def add_edges_list(root, code):\n",
    "    '''\n",
    "    This function creates the graph that represents a molecule. \n",
    "    Input: root: root directory of hydrogen-free sdf files of ligands 'str'.\n",
    "           code: PDB code of the complex containing the ligand 'str'. \n",
    "    Output: edges list 'numpy array'\n",
    "            \n",
    "    '''\n",
    "    ligand_filename = code + \"_h.sdf\"\n",
    "    m = MolFromMolFile(join(root, code, ligand_filename))\n",
    "    atoms1 = [b.GetBeginAtomIdx() for b in m.GetBonds()]\n",
    "    atoms2 = [b.GetEndAtomIdx() for b in m.GetBonds()]    \n",
    "    # Edge attributes: distance; SINGLE; DOUBLE; TRIPLE; AROMATIC.\n",
    "    edge_weights= []\n",
    "    coords = m.GetConformers()[0].GetPositions()  # Get a const reference to the vector of atom positions\n",
    "    for b in m.GetBonds():\n",
    "        if str(b.GetBondType()) == \"SINGLE\":\n",
    "            edge_weights.append(1)\n",
    "        elif str(b.GetBondType()) == \"DOUBLE\":\n",
    "            edge_weights.append(2)\n",
    "        elif str(b.GetBondType()) == \"TRIPLE\":\n",
    "            edge_weights.append(3)\n",
    "        else:\n",
    "            edge_weights.append(4)\n",
    "    edge_features = np.array(edge_weights) \n",
    "    # since the torch-geometric graphs are directed add reverse direction of edges\n",
    "    return np.array([atoms1 + atoms2, atoms2 + atoms1]), np.concatenate((edge_features, edge_features), 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class PDBbindDataset(InMemoryDataset):\n",
    "    '''\n",
    "    PDBbindDataset Class: A dataset class that takes root directory, BPS features, and binding affinity values of \n",
    "    samples as input. For each complex, the class stores a graph where node features are precomputed BPS\n",
    "    features, edge features are bond types and target value is the equivalent binding affinity of the given complex.\n",
    "    Output: A list of graphs.\n",
    "    '''\n",
    "    def __init__(self, root, node_features,activity_csv,transform=None, pre_transform=None):\n",
    "        self.root = root\n",
    "        self.node_features = node_features\n",
    "        self.activity_csv = activity_csv\n",
    "        super(PDBbindDataset, self).__init__(root, transform, pre_transform)\n",
    "        self.data, self.slices = torch.load(self.processed_paths[0])\n",
    "\n",
    "    @property\n",
    "    def raw_file_names(self):\n",
    "        return []\n",
    "\n",
    "    @property\n",
    "    def processed_file_names(self):\n",
    "        return [\"data.pt\"]\n",
    "\n",
    "    def download(self):\n",
    "        pass\n",
    "\n",
    "    def process(self):\n",
    "\n",
    "        self.node_data = dd.io.load(join(self.root, self.node_features))\n",
    "        # load csv with activity data and simlarity scores\n",
    "        self.activity = pd.read_csv(join(self.root, self.activity_csv))\n",
    "        # create lists of edges and edge descriptors \n",
    "        self.edge_indexes = {key: add_edges_list(self.root, key)[0] for key in self.activity.PDB }\n",
    "        self.edge_data = {key: add_edges_list(self.root, key)[1] for key in self.activity.PDB }\n",
    "        \n",
    "        # Read data into huge `Data` list.\n",
    "        data_list = [Data(x = torch.FloatTensor(self.node_data[key]),\n",
    "                          edge_index = torch.LongTensor(self.edge_indexes[key]),\n",
    "                          edge_attr = torch.FloatTensor(self.edge_data[key]),\n",
    "                          y = torch.FloatTensor([self.activity[self.activity.PDB == key].pk.iloc[0]])) for key in self.activity.PDB ]\n",
    "      \n",
    "        if self.pre_filter is not None:\n",
    "            data_list = [data for data in data_list if self.pre_filter(data)]\n",
    "\n",
    "        if self.pre_transform is not None:\n",
    "            data_list = [self.pre_transform(data) for data in data_list]\n",
    "        \n",
    "        data, slices = self.collate(data_list)\n",
    "        torch.save((data, slices), self.processed_paths[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "@torch.no_grad()\n",
    "def test_predictions(model, loader):\n",
    "    model.eval()\n",
    "    pred = []\n",
    "    true = []\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        pred += model(data).detach().cpu().numpy().tolist()\n",
    "        true += data.y.detach().cpu().numpy().tolist()\n",
    "    return pred, true"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Loading the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model1= Net() \n",
    "model2= Net() \n",
    "model3= Net() \n",
    "model4= Net() \n",
    "model1.load_state_dict(torch.load('GNNs_fusion_rand_split.pt'))\n",
    "model2.load_state_dict(torch.load('GNNs_fusion_TM_split.pt'))\n",
    "model3.load_state_dict(torch.load('GNNs_fusion_lig2lig_split.pt'))\n",
    "model4.load_state_dict(torch.load('GNNs_fusion_interaction_split.pt'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results on the CSAR sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_to_test_directory = \"set1\" \n",
    "path_to_test_complexes_features =  \"set1.h5\"\n",
    "path_to_test_complexes_true_affinity = \"QSAR_NRC_HiQ_Set.csv\"\n",
    "set1 = PDBbindDataset(path_to_test_directory\n",
    "                        , path_to_test_complexes_features\n",
    "                        , path_to_test_complexes_true_affinity)\n",
    "test_loader = DataLoader(set1, batch_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred, true = test_predictions(model1, test_loader)\n",
    "pred2, _ = test_predictions(model2, test_loader)\n",
    "pred3, _ = test_predictions(model3, test_loader)\n",
    "pred4, _ = test_predictions(model4, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results on CSAR_set1\n",
      "RMSE Random split:  1.22\n",
      "RMSE PPS:  1.38\n",
      "RMSE LLS:  1.26\n",
      "RMSE CCS:  1.2700000000000002\n"
     ]
    }
   ],
   "source": [
    "print(\"Results on CSAR_set1\")\n",
    "print(\"RMSE Random split: \",np.sqrt(mean_squared_error(pred,true)))\n",
    "print(\"RMSE PPS: \",np.sqrt(mean_squared_error(pred2,true)))\n",
    "print(\"RMSE LLS: \",np.sqrt(mean_squared_error(pred3,true)))\n",
    "print(\"RMSE CCS: \",np.sqrt(mean_squared_error(pred4,true)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_to_test_directory = \"set2\"  \n",
    "path_to_test_complexes_features =  \"set2.h5\"\n",
    "path_to_test_complexes_true_affinity = \"QSAR_NRC_HiQ_Set.csv\"\n",
    "set2 = PDBbindDataset(path_to_test_directory\n",
    "                        , path_to_test_complexes_features\n",
    "                        , path_to_test_complexes_true_affinity)\n",
    "test_loader = DataLoader(set2, batch_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred, true = test_predictions(model1, test_loader)\n",
    "pred2, _ = test_predictions(model2, test_loader)\n",
    "pred3, _ = test_predictions(model3, test_loader)\n",
    "pred4, _ = test_predictions(model4, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results on CSAR_set2\n",
      "RMSE Random split:  1.0099999999999998\n",
      "RMSE PPS:  1.11\n",
      "RMSE LLS:  1.18\n",
      "RMSE CCS:  1.09\n"
     ]
    }
   ],
   "source": [
    "print(\"Results on CSAR_set2\")\n",
    "print(\"RMSE Random split: \",np.sqrt(mean_squared_error(pred,true)))\n",
    "print(\"RMSE PPS: \",np.sqrt(mean_squared_error(pred2,true)))\n",
    "print(\"RMSE LLS: \",np.sqrt(mean_squared_error(pred3,true)))\n",
    "print(\"RMSE CCS: \",np.sqrt(mean_squared_error(pred4,true)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
